{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *\n",
    "from fastai.tabular.core import *\n",
    "from fastai.tabular.model import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.tabular.data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp tabular.learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular learner\n",
    "\n",
    "> The function to immediately get a `Learner` ready to train for tabular data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main function you probably want to use in this module is `tabular_learner`. It will automatically create a `TabularModel` suitable for your data and infer the right loss function. See the [tabular tutorial](http://docs.fast.ai/tutorial.tabular) for an example of use in context."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TabularLearner(Learner):\n",
    "    \"`Learner` for tabular data\"\n",
    "    def predict(self, row):\n",
    "        \"Predict on a Pandas Series `row`; returns the decoded row with its prediction, the decoded prediction and the raw prediction\"\n",
    "        # Turn the Series into a one-row frame (`to_frame().T`) and build a test `DataLoader` from it\n",
    "        one_row_df = row.to_frame().T\n",
    "        test_dl = self.dls.test_dl(one_row_df)\n",
    "        # Cast the continuous columns to float32 before running the model\n",
    "        test_dl.dataset.conts = test_dl.dataset.conts.astype(np.float32)\n",
    "        inputs,raw_preds,_,decoded_preds = self.get_preds(dl=test_dl, with_input=True, with_decoded=True)\n",
    "        # Inputs followed by decoded predictions, flattened into a single tuple for `dls.decode`\n",
    "        batch = (*tuplify(inputs),*tuplify(decoded_preds))\n",
    "        decoded_row = self.dls.decode(batch)\n",
    "        # Only one row was passed in, so keep the first (and only) prediction\n",
    "        return decoded_row,decoded_preds[0],raw_preds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"TabularLearner\" class=\"doc_header\"><code>class</code> <code>TabularLearner</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>TabularLearner</code>(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*) :: `Learner`\n",
       "\n",
       "`Learner` for tabular data"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TabularLearner, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It works exactly like a normal `Learner`; the only difference is that it implements a `predict` method designed to work on a single row of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(Learner.__init__)\n",
    "def tabular_learner(dls, layers=None, emb_szs=None, config=None, n_out=None, y_range=None, **kwargs):\n",
    "    \"Get a `Learner` using `dls`, with `metrics`, including a `TabularModel` created using the remaining params.\"\n",
    "    # Work on a shallow copy so the caller's `config` is never mutated and can safely be reused\n",
    "    config = tabular_config() if config is None else dict(config)\n",
    "    if layers is None: layers = [200,100]\n",
    "    emb_szs = get_emb_sz(dls.train_ds, {} if emb_szs is None else emb_szs)\n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    # Always remove `y_range` from `config` so it is never passed twice to `TabularModel`;\n",
    "    # an explicit `y_range` argument takes precedence over the one stored in `config`\n",
    "    config_y_range = config.pop('y_range', None)\n",
    "    if y_range is None: y_range = config_y_range\n",
    "    model = TabularModel(emb_szs, len(dls.cont_names), n_out, layers, y_range=y_range, **config)\n",
    "    # Everything left in `kwargs` is forwarded to the `Learner` constructor\n",
    "    return TabularLearner(dls, model, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your data was built with fastai, you probably won't need to pass anything to `emb_szs` unless you want to change the default of the library (produced by `get_emb_sz`), same for `n_out` which should be automatically inferred. `layers` will default to `[200,100]` and is passed to `TabularModel` along with the `config`.\n",
    "\n",
    "Use `tabular_config` to create a `config` and customize the model used. `y_range` is also exposed as a direct argument of `tabular_learner` because it is so often used.\n",
    "\n",
    "All the other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the Adult sample data through `untar_data` and read its CSV file\n",
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')\n",
    "\n",
    "# Categorical / continuous columns and the preprocessing steps applied to them\n",
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [Categorify, FillMissing, Normalize]\n",
    "\n",
    "# Rows 800-999 are held out as the validation set\n",
    "dls = TabularDataLoaders.from_df(\n",
    "    df, path,\n",
    "    procs=procs,\n",
    "    cat_names=cat_names,\n",
    "    cont_names=cont_names,\n",
    "    y_names=\"salary\",\n",
    "    valid_idx=list(range(800,1000)),\n",
    "    bs=64,\n",
    ")\n",
    "learn = tabular_learner(dls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TabularLearner.predict\" class=\"doc_header\"><code>TabularLearner.predict</code><a href=\"__main__.py#L5\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TabularLearner.predict</code>(**`row`**)\n",
       "\n",
       "Predict on a Pandas Series"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TabularLearner.predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can pass an individual row of data into our `TabularLearner`'s `predict` method. Its output is slightly different from the other `predict` methods, as this one will always return the input as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "row, clas, probs = learn.predict(df.iloc[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>education-num_na</th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education-num</th>\n",
       "      <th>salary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Private</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>#na#</td>\n",
       "      <td>Wife</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>49.0</td>\n",
       "      <td>101320.001685</td>\n",
       "      <td>12.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "row.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0), tensor([0.5264, 0.4736]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clas, probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#test y_range is passed, either directly or through `config`\n",
    "for y_range_kwargs in (dict(y_range=(0,32)), dict(config=tabular_config(y_range=(0,32)))):\n",
    "    learn = tabular_learner(dls, **y_range_kwargs)\n",
    "    # The model must end with a `SigmoidRange` carrying the requested bounds\n",
    "    last_layer = learn.model.layers[-1]\n",
    "    assert isinstance(last_layer, SigmoidRange)\n",
    "    test_eq(last_layer.low, 0)\n",
    "    test_eq(last_layer.high, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x:Tabular, y:Tabular, samples, outs, ctxs=None, max_n=10, **kwargs):\n",
    "    \"Display the first `max_n` rows of `x` with one extra `<name>_pred` column per target, filled from `y`\"\n",
    "    shown = x.all_cols[:max_n]\n",
    "    for y_name in x.y_names:\n",
    "        # Values of `y` for this target, truncated to the same `max_n` rows as the inputs\n",
    "        shown[y_name+'_pred'] = y[y_name][:max_n].values\n",
    "    display_df(shown)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
